feat(srt): support prefill and generate with input_embeds
#2082
Conversation
The input_embeds branch was force-pushed from 857750a to 8058e22.
Thanks for the contribution. There is a related PR that was opened recently; could you take a look at it? #2052
Force-pushed from 976f2c0 to 98331c0.
Thanks for the contribution! I left a few comments.
Can you add a test case for llama and llava?
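For reference, a test along these lines might look like the following sketch. It assumes the server exposes an input_embeds field on /generate as this PR intends; the endpoint, port, and model name are assumptions for illustration, not confirmed API.

# Hypothetical sketch: exercise the new input_embeds path end to end.
# Assumes a locally running SRT server and that /generate accepts "input_embeds".
import requests
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL = "meta-llama/Llama-3.1-8B-Instruct"  # assumption: whatever llama checkpoint the server is launched with

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)

prompt = "The capital of France is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Compute the same embeddings the model would use internally.
embeds = model.get_input_embeddings()(input_ids)[0].tolist()

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # assumption: default SRT port
    json={
        "input_embeds": embeds,
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    },
)
print(resp.json())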
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams

# Use sequence instead of Tensor here because Pydantic serializes Python objects
sequence or list?
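For what it is worth, a JSON-friendly typing along these lines would sidestep the question. This is only a sketch; the class name, field name, and use of a Pydantic model here are assumptions for illustration.

from typing import List, Optional, Union
from pydantic import BaseModel

class GenerateReqInputSketch(BaseModel):
    # Nested Python lists serialize cleanly: a single sequence is
    # List[List[float]] and a batch is List[List[List[float]]].
    # Convert to torch.Tensor only inside the runtime.
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None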
if sys.version_info >= (3, 10):
    _: dataclasses.KW_ONLY
What is this used for?
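For context, dataclasses.KW_ONLY is a Python 3.10+ sentinel: every field declared after it becomes keyword-only, which lets fields without defaults follow fields that do have defaults. A minimal standalone example:

import dataclasses
import sys

if sys.version_info >= (3, 10):
    @dataclasses.dataclass
    class Example:
        a: int = 0
        _: dataclasses.KW_ONLY
        b: int  # no default, but allowed after `a` because it is keyword-only

    Example(1, b=2)   # ok
    # Example(1, 2)   # TypeError: b is keyword-only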
@@ -430,6 +435,9 @@ def __repr__(self):
class ScheduleBatch:
    """Store all information of a batch."""

    if sys.version_info >= (3, 10):
Is it possible to get rid of this?
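One possible way to drop the version gate, sketched under the assumption that the new field can simply take a default; the class and fields below are illustrative, not the actual ScheduleBatch.

import dataclasses
from typing import Optional
import torch

@dataclasses.dataclass
class ScheduleBatchSketch:
    # A defaulted field can follow other defaulted fields on any Python
    # version, so no sys.version_info / KW_ONLY branch is needed.
    input_ids: Optional[torch.Tensor] = None
    input_embeds: Optional[torch.Tensor] = None  # new field, defaulted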
@@ -876,7 +902,7 @@ def check_for_jump_forward(self, pad_input_ids_func):
                jump_forward_reqs.append(req)
                keep_indices.remove(i)

-        self.filter_batch(keep_indices=list(keep_indices))
+        self.filter_batch(keep_indices=sorted(keep_indices))
Why is sorted better here?
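For context, keep_indices appears to be built as a set, and set iteration order is not the original index order; sorted() makes the kept requests come out in their original, deterministic order. A small illustration:

# Illustration: set iteration order is not guaranteed to be the original
# index order, so filtering a batch with list(keep_indices) can permute requests.
keep_indices = set(range(5))
keep_indices.remove(2)
print(list(keep_indices))    # e.g. [0, 1, 3, 4] today, but the order is not guaranteed
print(sorted(keep_indices))  # always [0, 1, 3, 4], preserving request order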
(
    logits_output,
    next_token_ids,
    next_token_embeds,
Why do you need next_token_embeds? I think after the first prefill we can use token ids and do not need to take embedding inputs anymore.
(
    logits_output,
    next_token_ids,
    next_token_embeds,
I think next_token_embeds is probably not necessary here; it makes things much more complicated. Some of the handling here is not correct, because the copies of these tensors need to be managed carefully. Ideally, we can get rid of next_token_embeds and avoid changing this file at all.
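A sketch of the flow this suggests, with illustrative function names (not the actual scheduler API): only the first forward pass consumes the provided embeddings, and every later step runs on the sampled token ids, so nothing embedding-shaped has to be threaded through this file.

# Illustrative only: prefill consumes input_embeds once, decode uses ids.
logits_output, next_token_ids = run_prefill(batch, input_embeds=batch.input_embeds)
while not batch.is_finished():
    batch.append_token_ids(next_token_ids)          # token ids, not embeddings
    logits_output, next_token_ids = run_decode(batch)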
@@ -211,6 +218,11 @@ def init_new(
            forward_mode=batch.forward_mode,
            batch_size=len(batch.seq_lens),
            input_ids=batch.input_ids,
            input_embeds=(
                batch.input_embeds.clone().detach().to(device)
Can we get rid of this extra copy?
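One possible shape for that, assuming the tensor is not mutated downstream: .to(device) already allocates a fresh tensor whenever the source lives on another device, so the clone() can likely be dropped and only the (free) detach() kept.

# Sketch: avoid the extra allocation from clone(); .to(device) copies only
# when the devices differ, and .detach() alone does not copy data.
input_embeds=(
    batch.input_embeds.detach().to(device, non_blocking=True)
    if batch.input_embeds is not None
    else None
),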
def forward_decode(self, forward_batch: ForwardBatch):
    if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
        return self.cuda_graph_runner.replay(forward_batch)

    forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
    self.attn_backend.init_forward_metadata(forward_batch)

    if forward_batch.input_embeds is not None:
We probably only need input_embeds for prefill.
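A sketch of that restriction, with the surrounding code paraphrased and the model call signature assumed rather than taken from the actual file: keep the embedding branch in the extend/prefill path and let decode assume token ids are always present.

# Sketch (assumed names): restrict input_embeds to the extend/prefill path.
def forward_decode(self, forward_batch: ForwardBatch):
    if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
        return self.cuda_graph_runner.replay(forward_batch)

    forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
    self.attn_backend.init_forward_metadata(forward_batch)
    # Decode steps always have sampled token ids, so no embedding path here.
    assert forward_batch.input_embeds is None
    return self.model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )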
I feel #2052 is probably a cleaner solution.
Force-pushed from 98331c0 to 62e3104.
Force-pushed from 62e3104 to 4e28940.
Will close this one in favor of #2052.
Motivation
Resolves #745
Modifications
As per the commit messages.
Checklist